#!/usr/bin/env python3

import json
import os
import socket
import uuid

# Connection parameters; each one can be overridden through the environment.
CLICKHOUSE_HOST = os.environ.get("CLICKHOUSE_HOST", "127.0.0.1")
# The fallback must be a valid TCP port (0-65535); 9000 matches the address
# literal used in serializeClientInfo. The previous fallback of 900000 made
# connect() fail whenever CLICKHOUSE_PORT_TCP was not set.
CLICKHOUSE_PORT = int(os.environ.get("CLICKHOUSE_PORT_TCP", "9000"))
CLICKHOUSE_DATABASE = os.environ.get("CLICKHOUSE_DATABASE", "default")
# Client name sent in the Hello packet and in the per-query client info.
CLIENT_NAME = "simple native protocol"


def writeVarUInt(x, ba):
    """Append *x* to *ba* as a variable-length unsigned integer.

    The value is emitted 7 bits per byte, least-significant group first.
    The high bit of a byte is set when more significant bits remain.
    At most 9 bytes are written; larger values are silently truncated.
    """
    pending = x
    written = 0
    while written < 9:
        chunk = pending & 0x7F
        # Flag the byte with the continuation bit if higher bits are left.
        ba.append(chunk | 0x80 if pending > 0x7F else chunk)
        pending >>= 7
        written += 1
        if not pending:
            break


def writeStringBinary(s, ba):
    """Append string *s* to *ba* as a length-prefixed UTF-8 byte sequence.

    The prefix is written with writeVarUInt and holds the number of encoded
    BYTES, not characters. The two differ for any non-ASCII text, and the
    reader (readStringBinary) consumes exactly that many bytes.
    """
    b = bytes(s, "utf-8")
    writeVarUInt(len(b), ba)
    ba.extend(b)


def readStrict(s, size=1):
    """Read exactly *size* bytes from socket *s*.

    recv() may return fewer bytes than requested, so it is called in a loop
    until the requested amount has been collected.

    Returns:
        A bytearray of length *size* (empty when *size* is 0).

    Raises:
        ConnectionError: if the peer closes the connection before *size*
            bytes have arrived.
    """
    res = bytearray()
    remaining = size
    while remaining > 0:
        cur = s.recv(remaining)
        if not cur:
            # recv() returns b"" once the peer has closed the connection.
            # Without this check the loop would spin forever.
            raise ConnectionError(
                "socket closed by peer with {} of {} byte(s) still unread".format(
                    remaining, size
                )
            )
        remaining -= len(cur)
        res.extend(cur)

    return res


def readUInt(s, size=1):
    """Read *size* bytes from socket *s* and decode them as a little-endian
    unsigned integer (byte 0 is the least significant)."""
    return int.from_bytes(readStrict(s, size), "little")


def readUInt8(s):
    """Read a 1-byte unsigned integer from socket *s*."""
    return readUInt(s, size=1)


def readUInt16(s):
    """Read a 2-byte little-endian unsigned integer from socket *s*."""
    return readUInt(s, size=2)


def readUInt32(s):
    """Read a 4-byte little-endian unsigned integer from socket *s*."""
    return readUInt(s, size=4)


def readUInt64(s):
    """Read an 8-byte little-endian unsigned integer from socket *s*."""
    return readUInt(s, size=8)


def readVarUInt(s):
    """Read a variable-length unsigned integer from socket *s*.

    Each byte contributes its low 7 bits, least-significant group first.
    A cleared high bit marks the final byte. At most 9 bytes are consumed.
    """
    value = 0
    shift = 0
    while shift < 7 * 9:
        octet = readStrict(s)[0]
        value |= (octet & 0x7F) << shift
        if (octet & 0x80) == 0:
            # No continuation bit: this was the last byte of the number.
            break
        shift += 7

    return value


def readStringBinary(s):
    """Read a length-prefixed string from socket *s*.

    The length comes first as a variable-length unsigned integer, followed
    by that many bytes, which are decoded as UTF-8.
    """
    length = readVarUInt(s)
    payload = readStrict(s, length)
    return payload.decode("utf-8")


def sendHello(s):
    """Build the client Hello packet and send it over socket *s*."""
    packet = bytearray()
    writeVarUInt(0, packet)  # packet type: Hello
    writeStringBinary(CLIENT_NAME, packet)
    # NOTE(review): presumably client version major, minor and protocol
    # revision, mirroring the order of the fields receiveHello reads.
    for number in (21, 9, 54449):
        writeVarUInt(number, packet)
    # database, user, password (empty)
    for text in (CLICKHOUSE_DATABASE, "default", ""):
        writeStringBinary(text, packet)
    s.sendall(packet)


def receiveHello(s):
    """Consume the server's reply to the client Hello from socket *s*.

    The Hello fields are read only to advance the stream past the packet;
    their values are not used and nothing is returned.

    Raises:
        RuntimeError: if the server answers with an Exception packet
            (type 2); the message comes from readException.
        AssertionError: if any other non-Hello packet type arrives.
    """
    p_type = readVarUInt(s)
    if p_type == 2:  # Exception, e.g. the server rejected the handshake
        raise RuntimeError(readException(s))
    assertPacket(p_type, 0)  # Hello

    readStringBinary(s)  # server name
    readVarUInt(s)  # server version major
    readVarUInt(s)  # server version minor
    readVarUInt(s)  # server revision
    readStringBinary(s)  # server timezone
    readStringBinary(s)  # server display name
    readVarUInt(s)  # server version patch


def serializeClientInfo(ba, query_id):
    """Append the client-info section of a query packet to *ba*.

    Every value except *query_id* (used as the initial query id) is a fixed
    placeholder.
    """
    # initial_user, initial_query_id, initial_address
    for text in ("default", query_id, "127.0.0.1:9000"):
        writeStringBinary(text, ba)
    ba.extend(bytes(8))  # initial_query_start_time_microseconds
    ba.append(1)  # TCP
    # os_user, client_hostname, client_name
    for text in ("os_user", "client_hostname", CLIENT_NAME):
        writeStringBinary(text, ba)
    # Same three numbers that sendHello sends after the client name.
    for number in (21, 9, 54449):
        writeVarUInt(number, ba)
    writeStringBinary("", ba)  # quota_key
    writeVarUInt(0, ba)  # distributed_depth
    writeVarUInt(1, ba)  # client_version_patch
    ba.append(0)  # No telemetry


def sendQuery(s, query):
    """Serialize a Query packet for *query* and send it over socket *s*.

    A fresh random hex UUID serves both as the query id and as the initial
    query id inside the client info.
    """
    query_id = uuid.uuid4().hex

    packet = bytearray()
    writeVarUInt(1, packet)  # packet type: query
    writeStringBinary(query_id, packet)
    packet.append(1)  # INITIAL_QUERY
    serializeClientInfo(packet, query_id)  # client info
    writeStringBinary("", packet)  # No settings
    writeStringBinary("", packet)  # No interserver secret
    writeVarUInt(2, packet)  # Stage - Complete
    packet.append(0)  # No compression
    writeStringBinary(query, packet)  # the query text itself

    s.sendall(packet)


def serializeBlockInfo(ba):
    """Append an all-default block-info section to *ba*.

    The result is the varuint 1, a zero is_overflows byte, the varuint 2,
    the varuint 0 and four zero bytes for bucket_num.
    """
    writeVarUInt(1, ba)  # 1
    ba.append(0)  # is_overflows
    for marker in (2, 0):
        writeVarUInt(marker, ba)
    ba.extend(bytes(4))  # bucket_num


def sendEmptyBlock(s):
    """Send a Data packet holding an empty block over socket *s*."""
    packet = bytearray()
    writeVarUInt(2, packet)  # packet type: Data
    # NOTE(review): empty string field, presumably the table name - confirm
    # against the protocol documentation.
    writeStringBinary("", packet)
    serializeBlockInfo(packet)
    # Row count and column count, both zero.
    for _ in range(2):
        writeVarUInt(0, packet)
    s.sendall(packet)


def assertPacket(packet, expected):
    """Check that the value read from the wire equals *expected*.

    The check is an explicit raise rather than an ``assert`` statement so
    that it still runs under ``python -O``.

    Raises:
        AssertionError: if *packet* differs from *expected*; the message
            reports both values.
    """
    if packet != expected:
        raise AssertionError(f"got {packet!r}, expected {expected!r}")


class Progress:
    """Counters carried by Progress packets.

    Instances hold either one packet's values (filled in by readPacket) or
    a running total built up with ``+=``.
    """

    def __init__(self):
        # NOTE: the counters are assigned in the constructor so that they
        # land in the instance __dict__, which __str__ serializes.
        self.read_rows = 0
        self.read_bytes = 0
        self.total_rows_to_read = 0
        self.written_rows = 0
        self.written_bytes = 0

    def __str__(self):
        """Return the counters as a JSON object, in declaration order."""
        return json.dumps(self.__dict__)

    def __iadd__(self, b):
        """Add the counters of *b* into this instance and return it."""
        self.read_rows += b.read_rows
        self.read_bytes += b.read_bytes
        self.total_rows_to_read += b.total_rows_to_read
        self.written_rows += b.written_rows
        self.written_bytes += b.written_bytes
        return self

    def __add__(self, b):
        """Return a new Progress with the summed counters.

        Neither operand is modified; a plain ``a + b`` must not change
        ``a`` as a side effect. In-place accumulation is done by ``+=``.
        """
        total = Progress()
        total += self
        total += b
        return total

    def readPacket(self, s):
        """Read one Progress packet body from socket *s*.

        The five counters arrive as variable-length unsigned integers and
        are added to the current values.
        """
        self.read_rows += readVarUInt(s)
        self.read_bytes += readVarUInt(s)
        self.total_rows_to_read += readVarUInt(s)
        self.written_rows += readVarUInt(s)
        self.written_bytes += readVarUInt(s)

    def __bool__(self):
        """Return True when at least one counter is greater than zero."""
        return (
            self.read_rows > 0
            or self.read_bytes > 0
            or self.total_rows_to_read > 0
            or self.written_rows > 0
            or self.written_bytes > 0
        )


def readProgress(s):
    """Read the next server packet from socket *s* as a progress update.

    Returns:
        A new Progress filled from the packet, or None when the server
        signals the end of the stream.

    Raises:
        RuntimeError: if the server sent an Exception packet.
        AssertionError: if any other unexpected packet type arrives.
    """
    packet_type = readVarUInt(s)

    if packet_type == 5:  # End stream
        return None
    if packet_type == 2:  # Exception
        raise RuntimeError(readException(s))
    assertPacket(packet_type, 3)  # Progress

    update = Progress()
    update.readPacket(s)
    return update


def readException(s):
    """Read an Exception packet body from socket *s*.

    Returns a short message made of the numeric code and the exception
    text with every "DB::Exception:" marker removed. The name and stack
    trace are read and discarded. Nested exceptions are not supported: the
    has_nested flag must be 0.
    """
    code = readUInt32(s)
    readStringBinary(s)  # name, not used
    message = readStringBinary(s)
    readStringBinary(s)  # trace
    assertPacket(readUInt8(s), 0)  # has_nested
    cleaned = message.replace("DB::Exception:", "")
    return "code {}: {}".format(code, cleaned)


def main():
    """Run an INSERT ... SELECT over the native protocol and check progress.

    Connects to the server, performs the Hello exchange, sends the query
    followed by an empty block, then reads packets until the end of the
    stream. The summed progress is printed, and the number of non-empty
    progress packets is asserted to be 3 or 4.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.settimeout(30)
        s.connect((CLICKHOUSE_HOST, CLICKHOUSE_PORT))
        sendHello(s)
        receiveHello(s)
        # With a 1 second sleep and 1000ms of interactive_delay we definitely should have a non-zero progress packet.
        # NOTE: interactive_delay=0 cannot be used since in this case CompletedPipelineExecutor will not call cancelled callback.
        sendQuery(
            s,
            "insert into function null('_ Int') select sleep(1) from numbers(2) settings max_block_size=1, interactive_delay=1000",
        )

        # external tables
        sendEmptyBlock(s)

        summary_progress = Progress()
        non_empty_progress_packets = 0
        # readProgress returns None once the server signals end of stream.
        while (progress := readProgress(s)) is not None:
            summary_progress += progress
            if progress:
                non_empty_progress_packets += 1

        print(summary_progress)
        # Count only non-empty progress packets; eventually we should have 3 or 4 of them:
        # - 2 for each INSERT block (one of them can be merged with the read block, hence 3 or 4)
        # - 1 or 2 for each SELECT block
        assert non_empty_progress_packets in (3, 4), f"{non_empty_progress_packets=:}"
        # No explicit close() here: leaving the with-block closes the socket.


# Run the check only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
